{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Fast Sign Adversary Generation Example\n",
    "\n",
    "This notebook demonstrates how to generate adversarial examples with the fast gradient sign method, using the MXNet symbolic API together with NumPy.\n",
    "\n",
    "Reference:\n",
    "\n",
    "[1] Goodfellow, Ian J., Jonathon Shlens, and Christian Szegedy. \"Explaining and harnessing adversarial examples.\" arXiv preprint arXiv:1412.6572 (2014).\n",
    "https://arxiv.org/abs/1412.6572"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import random as rnd\n",
    "\n",
    "import matplotlib.cm as cm\n",
    "import matplotlib.pyplot as plt\n",
    "import mxnet as mx\n",
    "import numpy as np\n",
    "\n",
    "from mxnet.test_utils import get_mnist_iterator"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Build Network\n",
    "\n",
    "Note: in this network, the softmax and the loss gradient are computed in NumPy rather than inside the symbolic graph."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Mini-batch size and compute context shared by the later cells.\n",
    "batch_size = 100\n",
    "dev = mx.cpu()\n",
    "\n",
    "# MNIST train / validation iterators with single-channel 28x28 input.\n",
    "train_iter, val_iter = get_mnist_iterator(\n",
    "    batch_size=batch_size,\n",
    "    input_shape=(1, 28, 28))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Placeholder for the input image batch.\n",
    "data = mx.symbol.Variable('data')\n",
    "\n",
    "# Stage 1: 5x5 convolution (20 filters) -> tanh -> 2x2 max pooling.\n",
    "conv1 = mx.symbol.Convolution(data=data, num_filter=20, kernel=(5, 5))\n",
    "tanh1 = mx.symbol.Activation(data=conv1, act_type='tanh')\n",
    "pool1 = mx.symbol.Pooling(\n",
    "    data=tanh1, pool_type='max', kernel=(2, 2), stride=(2, 2))\n",
    "\n",
    "# Stage 2: 5x5 convolution (50 filters) -> tanh -> 2x2 max pooling.\n",
    "conv2 = mx.symbol.Convolution(data=pool1, num_filter=50, kernel=(5, 5))\n",
    "tanh2 = mx.symbol.Activation(data=conv2, act_type='tanh')\n",
    "pool2 = mx.symbol.Pooling(\n",
    "    data=tanh2, pool_type='max', kernel=(2, 2), stride=(2, 2))\n",
    "\n",
    "# Hidden fully connected layer: flatten -> 500 units -> tanh.\n",
    "flatten = mx.symbol.Flatten(data=pool2)\n",
    "fc1 = mx.symbol.FullyConnected(data=flatten, num_hidden=500)\n",
    "tanh3 = mx.symbol.Activation(data=fc1, act_type='tanh')\n",
    "\n",
    "# Output layer: 10 raw scores. No softmax symbol is attached here; the\n",
    "# softmax is applied in NumPy on the executor output.\n",
    "fc2 = mx.symbol.FullyConnected(data=tanh3, num_hidden=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def Softmax(theta):\n",
    "    \"\"\"Softmax of ``theta`` taken along axis 1.\n",
    "\n",
    "    The maximum along axis 1 is subtracted before exponentiating so that\n",
    "    large scores do not overflow; this shift cancels out in the ratio.\n",
    "    \"\"\"\n",
    "    shifted = theta - np.max(theta, axis=1, keepdims=True)\n",
    "    weights = np.exp(shifted)\n",
    "    return weights / np.sum(weights, axis=1, keepdims=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def LogLossGrad(alpha, label):\n",
    "    \"\"\"Return a copy of ``alpha`` with 1 subtracted at each row's label column.\n",
    "\n",
    "    ``alpha`` holds one row of values per sample and ``label[i]`` is the\n",
    "    column index for row ``i`` (converted with ``int``). The input array\n",
    "    is not modified.\n",
    "    \"\"\"\n",
    "    grad = np.copy(alpha)\n",
    "    # Each ``row`` is a view into ``grad``, so the update lands in ``grad``.\n",
    "    for i, row in enumerate(grad):\n",
    "        row[int(label[i])] -= 1.\n",
    "    return grad"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Prepare useful data for the network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Infer the shape of every network argument from the input batch shape.\n",
    "data_shape = (batch_size, 1, 28, 28)\n",
    "arg_names = fc2.list_arguments()  # 'data' plus the layer parameters\n",
    "arg_shapes, output_shapes, aux_shapes = fc2.infer_shape(data=data_shape)\n",
    "\n",
    "# Allocate one zero-filled value buffer and one gradient buffer per argument,\n",
    "# keyed by argument name.\n",
    "arg_map = {}\n",
    "grad_map = {}\n",
    "for name, shape in zip(arg_names, arg_shapes):\n",
    "    arg_map[name] = mx.nd.zeros(shape, ctx=dev)\n",
    "    grad_map[name] = mx.nd.zeros(shape, ctx=dev)\n",
    "\n",
    "# The executor takes the buffers as lists ordered like ``arg_names``.\n",
    "arg_arrays = [arg_map[name] for name in arg_names]\n",
    "grad_arrays = [grad_map[name] for name in arg_names]\n",
    "reqs = ['write'] * len(arg_names)\n",
    "\n",
    "model = fc2.bind(ctx=dev, args=arg_arrays, args_grad=grad_arrays, grad_req=reqs)\n",
    "\n",
    "# Gradient w.r.t. the input image (used later to build the perturbation)\n",
    "# and the buffer through which the NumPy loss gradient is fed to backward().\n",
    "data_grad = grad_map['data']\n",
    "out_grad = mx.nd.zeros(model.outputs[0].shape, ctx=dev)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Init weight "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Seed MXNet's random number generator so the initial weights are the same\n",
    "# on every Restart & Run All.\n",
    "mx.rnd.seed(42)\n",
    "\n",
    "# Only arguments whose name contains \"weight\" are randomised; every other\n",
    "# argument keeps the zero value it was allocated with.\n",
    "for name in arg_names:\n",
    "    if \"weight\" in name:\n",
    "        arr = arg_map[name]\n",
    "        arr[:] = mx.rnd.uniform(-0.07, 0.07, arr.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def SGD(weight, grad, lr=0.1, grad_norm=batch_size):\n",
    "    \"\"\"Apply one in-place SGD step to ``weight``.\n",
    "\n",
    "    The gradient is divided by ``grad_norm`` (the mini-batch size by\n",
    "    default) and scaled by the learning rate ``lr`` before being\n",
    "    subtracted from ``weight``.\n",
    "    \"\"\"\n",
    "    # Divide by the argument, not the global ``batch_size``, so that a\n",
    "    # caller-supplied normaliser is actually honoured.\n",
    "    weight[:] -= lr * grad / grad_norm\n",
    "\n",
    "\n",
    "def CalAcc(pred_prob, label):\n",
    "    \"\"\"Count the rows whose arg-max column equals the label.\n",
    "\n",
    "    Returns the count as a float; the caller divides by the batch size\n",
    "    to obtain an accuracy.\n",
    "    \"\"\"\n",
    "    hits = np.argmax(pred_prob, axis=1) == label\n",
    "    return np.sum(hits) * 1.0\n",
    "\n",
    "\n",
    "def CalLoss(pred_prob, label):\n",
    "    \"\"\"Sum of the negative log of each row's labelled entry.\n",
    "\n",
    "    Each entry is floored at 1e-10 before the logarithm so that a zero\n",
    "    probability does not produce an infinite loss.\n",
    "    \"\"\"\n",
    "    loss = 0.\n",
    "    for i, row in enumerate(pred_prob):\n",
    "        loss -= np.log(max(row[int(label[i])], 1e-10))\n",
    "    return loss"
   ]
  }
  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Train a network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "num_round = 4\n",
    "train_acc = 0.\n",
    "nbatch = 0\n",
    "for epoch in range(num_round):\n",
    "    # Start every pass with fresh running metrics and a rewound iterator.\n",
    "    train_loss, train_acc, nbatch = 0., 0., 0\n",
    "    train_iter.reset()\n",
    "    for batch in train_iter:\n",
    "        label = batch.label[0].asnumpy()\n",
    "\n",
    "        # Forward pass: the executor outputs raw scores, softmax is in NumPy.\n",
    "        arg_map['data'][:] = batch.data[0]\n",
    "        model.forward(is_train=True)\n",
    "        alpha = Softmax(model.outputs[0].asnumpy())\n",
    "\n",
    "        train_acc += CalAcc(alpha, label) / batch_size\n",
    "        train_loss += CalLoss(alpha, label) / batch_size\n",
    "\n",
    "        # Backward pass: feed the NumPy loss gradient into the executor.\n",
    "        out_grad[:] = LogLossGrad(alpha, label)\n",
    "        model.backward([out_grad])\n",
    "\n",
    "        # Update every argument except the input image itself.\n",
    "        for name in arg_names:\n",
    "            if name == 'data':\n",
    "                continue\n",
    "            SGD(arg_map[name], grad_map[name])\n",
    "\n",
    "        nbatch += 1\n",
    "\n",
    "    # Average the per-batch metrics over the pass.\n",
    "    train_acc /= nbatch\n",
    "    train_loss /= nbatch\n",
    "    print(\"Train Accuracy: %.2f\\t Train Loss: %.5f\" % (train_acc, train_loss))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Compute the perturbation with the fast gradient sign method and check how the validation accuracy changes.\n",
    "Note that the validation batch is classified almost entirely correctly before the perturbation, but after the perturbation the accuracy is much worse than random guessing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Take the first validation batch.\n",
    "val_iter.reset()\n",
    "batch = val_iter.next()\n",
    "data, label = batch.data[0], batch.label[0]\n",
    "true_label = label.asnumpy()\n",
    "\n",
    "# Accuracy on the unmodified images.\n",
    "arg_map['data'][:] = data\n",
    "model.forward(is_train=True)\n",
    "alpha = Softmax(model.outputs[0].asnumpy())\n",
    "print(\"Val Batch Accuracy: \", CalAcc(alpha, true_label) / batch_size)\n",
    "\n",
    "# Back-propagate the loss gradient; backward() fills data_grad with the\n",
    "# gradient with respect to the input pixels.\n",
    "out_grad[:] = LogLossGrad(alpha, true_label)\n",
    "model.backward([out_grad])\n",
    "\n",
    "# Fast sign step: move every pixel by 0.15 in the direction of the sign\n",
    "# of its gradient.\n",
    "noise = np.sign(data_grad.asnumpy())\n",
    "arg_map['data'][:] = data.asnumpy() + 0.15 * noise\n",
    "\n",
    "# Accuracy on the perturbed images.\n",
    "model.forward(is_train=True)\n",
    "pred = Softmax(model.outputs[0].asnumpy())\n",
    "print(\"Val Batch Accuracy after pertubation: \", CalAcc(pred, true_label) / batch_size)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Visualize an example after perturbation.\n",
    "Note that the prediction is consistently incorrect."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Rebuild the perturbed images with the same 0.15 step used for the\n",
    "# accuracy check.\n",
    "images = data.asnumpy() + 0.15 * noise\n",
    "\n",
    "# Pick a random image from the batch. randint includes both end points, so\n",
    "# the upper bound is the last valid index rather than a hard-coded 99.\n",
    "idx = rnd.randint(0, images.shape[0] - 1)\n",
    "\n",
    "plt.imshow(images[idx, :].reshape(28,28), cmap=cm.Greys_r)\n",
    "print(\"true: %d\" % label.asnumpy()[idx])\n",
    "print(\"pred: %d\" % np.argmax(pred, axis=1)[idx])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
